'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models.models import *
from utils import progress_bar
import numpy as np
from cifar10_input import *
from spatial_attacker import *

# Main training loop
def train(epoch):
    """Run one training epoch over ``trainloader`` and update ``net``.

    Every batch is passed through the spatial attacker ``SA`` (with gradient
    tracking disabled) to obtain a transformed copy.  The clean and the
    transformed images are then classified in a single forward pass, and the
    loss selected by ``args.train`` plus the optional regulariser (weighted by
    ``args.lamda``) is minimised with ``optimizer``.

    Args:
        epoch: Index of the current epoch; only used for the console header.

    Raises:
        ValueError: If ``args.train`` is not 'natural', 'robust' or 'mixed'.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0.0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        lenx = inputs.shape[0]

        # The attacker output is used as plain data, so no graph is built for it.
        with torch.no_grad():
            rotated, _ = SA(inputs*maskin, targets, net)

        # One forward pass for both halves: first `lenx` rows are the clean
        # images, the remaining `lenx` rows are the attacker's output.
        mix_batch = torch.cat([inputs, rotated], 0)

        optimizer.zero_grad()

        outs_mix = net(mix_batch*maskin)
        outnat, outadv = torch.split(outs_mix, [lenx, lenx])

        if args.train == 'natural':
            loss1 = criterion(outnat, targets)
            _, predicted = outnat.max(1)
            correct += predicted.eq(targets).sum().item()
        elif args.train == 'robust':
            loss1 = criterion(outadv, targets)
            _, predicted = outadv.max(1)
            correct += predicted.eq(targets).sum().item()
        elif args.train == 'mixed':
            loss1 = criterion(outnat, targets) + criterion(outadv, targets)
            _, predicted1 = outnat.max(1)
            _, predicted2 = outadv.max(1)
            # Average of the clean and transformed hit counts.
            correct += 0.5*(predicted1.eq(targets).sum().item()
                            + predicted2.eq(targets).sum().item())
        else:
            # Without this branch `loss1` would be unbound further down.
            raise ValueError(
                f"Unknown training mode {args.train!r}; "
                "expected 'natural', 'robust' or 'mixed'")

        reg_term = 0.0
        if args.reg is not None:
            reg_term = reg(outnat, outadv)

        loss = loss1 + args.lamda*reg_term
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += targets.size(0)

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def valid(epoch):
    """Evaluate ``net`` on ``validloader`` and checkpoint it when it improves.

    The same loss/accuracy definition as in training is used (selected by
    ``args.train``, plus the optional ``args.lamda``-weighted regulariser).
    A checkpoint is written to ``chckpt`` whenever the accuracy exceeds
    ``best_acc`` or the summed loss falls below ``best_loss``; both module
    level records are updated accordingly.

    Args:
        epoch: Index of the current epoch; stored in the checkpoint.

    Returns:
        The validation loss summed over all batches.

    Raises:
        ValueError: If ``args.train`` is not 'natural', 'robust' or 'mixed'.
    """
    global best_acc
    global best_loss
    net.eval()
    test_loss = 0
    correct = 0.0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(validloader):
            inputs, targets = inputs.to(device), targets.to(device)
            lenx = inputs.shape[0]

            rotated, _ = SA(inputs*maskin, targets, net)
            # First `lenx` rows: clean images; last `lenx` rows: attacker output.
            mix_batch = torch.cat([inputs, rotated], 0)
            outs_mix = net(mix_batch*maskin)

            outnat, outadv = torch.split(outs_mix, [lenx, lenx])
            if args.train == 'natural':
                loss1 = criterion(outnat, targets)
                _, predicted = outnat.max(1)
                correct += predicted.eq(targets).sum().item()
            elif args.train == 'robust':
                loss1 = criterion(outadv, targets)
                _, predicted = outadv.max(1)
                correct += predicted.eq(targets).sum().item()
            elif args.train == 'mixed':
                loss1 = criterion(outnat, targets) + criterion(outadv, targets)
                _, predicted1 = outnat.max(1)
                _, predicted2 = outadv.max(1)
                correct += 0.5*(predicted1.eq(targets).sum().item()
                                + predicted2.eq(targets).sum().item())
            else:
                # Without this branch `loss1` would be unbound further down.
                raise ValueError(
                    f"Unknown training mode {args.train!r}; "
                    "expected 'natural', 'robust' or 'mixed'")

            reg_term = 0.0
            if args.reg is not None:
                reg_term = reg(outnat, outadv)
            loss = loss1 + args.lamda*reg_term

            test_loss += loss.item()
            total += targets.size(0)

        # Reported once, after the whole validation pass has finished.
        progress_bar(batch_idx, len(validloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc or test_loss < best_loss:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'loss': test_loss,
        }
        # torch.save does not create missing directories; make sure the
        # target exists so a finished epoch is not lost at save time.
        ckpt_dir = os.path.dirname(chckpt)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, chckpt)
        if acc > best_acc:
            best_acc = acc
        if test_loss < best_loss:
            best_loss = test_loss
    return test_loss

def test():
    """Load the checkpoint at ``chckpt`` into ``net`` and score ``testloader``.

    Prints the accuracy and returns it.

    Returns:
        Top-1 accuracy on ``testloader`` in percent.
    """
    # map_location keeps the load working when the checkpoint was written
    # from a different device than the one in use now.
    checkpoint = torch.load(chckpt, map_location=device)
    net.load_state_dict(checkpoint['net'])
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            # NOTE(review): unlike train()/valid(), the inputs are not
            # multiplied by `maskin` here -- confirm this is intended.
            _, predicted = net(inputs).max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100.*correct/total
    print('Test accuracy', acc)
    return acc
def make_grid(num_angles=360, image_shape=(3, 32, 32)):
    """Build one rotation sampling grid per angle for ``F.grid_sample``.

    The angles are ``2*pi*k/num_angles`` for ``k = 0 .. num_angles-1``; with
    the default of 360 that is one grid per degree.

    Args:
        num_angles: Number of evenly spaced angles covering a full turn.
        image_shape: ``(channels, height, width)`` of the images to resample.

    Returns:
        The ``F.affine_grid`` output for the ``num_angles`` rotation matrices,
        located on the module-level ``device``.
    """
    # Local import: use the standard constant instead of overwriting the
    # global ``torch.pi`` attribute with a float32-precision value.
    import math

    rad = 2*math.pi*torch.arange(0.0, num_angles, 1.0)/num_angles
    cose = torch.cos(rad)
    sine = torch.sin(rad)
    zer0s = torch.zeros_like(cose)

    # Rows of the 2x3 affine matrix [[cos, -sin, 0], [sin, cos, 0]] per angle.
    top_rows = torch.stack([cose, -sine, zer0s], dim=1)
    bottom_rows = torch.stack([sine, cose, zer0s], dim=1)
    rotmats = torch.stack([top_rows, bottom_rows], dim=1).to(device)

    # align_corners=False is what the implicit default resolves to; passing it
    # explicitly avoids the default-change warning.
    grid = F.affine_grid(rotmats, (num_angles, *image_shape), align_corners=False)
    return grid

def adversarial_test(net,interp_mode,grid):
    """Score ``net`` on every test image resampled with every grid in ``grid``.

    Each batch from ``testloader`` is repeated once per grid entry and
    resampled with ``F.grid_sample``.  Because the repeated batch must have
    the same leading size as ``grid``, ``testloader`` has to yield one image
    per batch.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        interp_mode: Interpolation mode forwarded to ``F.grid_sample``.
        grid: Sampling grids; the first dimension is the number of
            transformations applied to each image.

    Returns:
        ``(rot_acc, adv_acc)`` in percent: ``rot_acc`` is the share of all
        (image, grid entry) pairs predicted correctly, ``adv_acc`` is the
        share of images predicted correctly for every grid entry.
    """
    net.eval()
    # Derive the count from the grid so it cannot drift from make_grid().
    n_angles = grid.shape[0]
    counts = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            inputs = inputs.repeat(n_angles, 1, 1, 1)
            # align_corners=False matches the implicit default and the
            # convention F.affine_grid uses when the argument is omitted.
            inputs = F.grid_sample(inputs, grid, mode=interp_mode,
                                   padding_mode='zeros', align_corners=False)

            outs = net(inputs)

            _, preds = outs.max(1)
            # An image counts as broken if any transformed copy is wrong.
            counts += int((preds != targets).any().item())
            correct += torch.sum(preds == targets).item()

        rot_acc = np.true_divide(correct, n_angles*len(testset))*100.
        counts = np.true_divide(counts, len(testset))*100.
        adv_acc = 100.-counts
    return rot_acc, adv_acc
     
def make_mask(inshape):
    """Return a boolean disc mask centred in an image of the given size.

    Args:
        inshape: ``(ny, nx)`` -- number of rows and columns of the mask.

    Returns:
        A ``(ny, nx)`` boolean tensor that is True where the distance to the
        centre pixel ``(ny // 2, nx // 2)`` is at most the largest radius
        that still fits inside the image.
    """
    ny, nx = inshape
    center = (int(ny/2), int(nx/2))
    radius = min(center[0], center[1], ny-center[0], nx-center[1])

    # Broadcast a column of row indices against a row of column indices so
    # the result is (ny, nx) for non-square shapes as well.
    rows = torch.arange(0.0, ny).unsqueeze(1)
    cols = torch.arange(0.0, nx).unsqueeze(0)
    dist_from_center = torch.sqrt((rows - center[0])**2 + (cols - center[1])**2)

    return dist_from_center <= radius

###### Argument parser ######
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')

parser.add_argument('--dataset',default='CIFAR10')
parser.add_argument('--augment', default='standard', help='data augmentation-- standard, rot-default, combi-rot, combi-rot-no-crop')
parser.add_argument('--train_batch_size',default=128,type=int)
parser.add_argument('--val_batch_size',default=100,type=int)

parser.add_argument('--rotation_range',default=180.0, type=float,help='range of rotation angles,(-180,180]')
parser.add_argument('--granularity',default=360, type=int,help='number of angles for grid search')
# Restricting the choices rejects typos at start-up instead of failing inside
# the first training batch, where only these three modes are handled.
parser.add_argument('--train', default='robust', choices=['natural', 'robust', 'mixed'],
                    help='training-- natural,robust,mixed')
parser.add_argument('--adv_criterion', default ='xent', help='criterion for adversarial example generation-- xent')
parser.add_argument('--reg', default= None, help='additional regularization, None, KL, ALP')
parser.add_argument('--lamda', default=1.0, type=float, help='regularization strength')
parser.add_argument('--k', default=10,type=int, help='compute adversarial example from worst of k, k=1 for random')
parser.add_argument('--mask', action='store_true',help='Mask during training')
parser.add_argument('--run',default=None,type=int)
parser.add_argument('--net', default='Resnet18',help='network-options--linear,convnet,Resnet18')
parser.add_argument('--chkpt_dir', default = './checkpoint/',  help='Check point directory')

args = parser.parse_args()
print(args)

# Join the directory and file name properly so --chkpt_dir works with or
# without a trailing separator, and create the directory up front because
# torch.save does not create it.
chckpt_name = f'{args.dataset}.{args.augment}.{args.train}.reg-{args.reg}.lamda-{args.lamda:.2f}.{args.net}.run-{args.run}.pth'
chckpt = os.path.join(args.chkpt_dir, chckpt_name)
if args.chkpt_dir:
    os.makedirs(args.chkpt_dir, exist_ok=True)

print(chckpt)
# Use the first GPU when one is available, otherwise run on the CPU.
device = 'cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best val accuracy
best_loss = 1000.0  # best (lowest) summed validation loss seen so far
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
end_epoch = 150  # training stops before this epoch index


# Data loaders
print('==> Preparing data..')
trainloader, validloader = make_dataset(args)

# Model
print('==> Building model..')
net = get_model(args.net)
net = net.to(device)
net = torch.nn.DataParallel(net)
cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)

# The scheduler is stepped with the validation loss in the training loop.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,verbose=True)
# Alternative schedule kept for reference:
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.25, last_epoch=-1, verbose=False)#
if args.resume:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load(chckpt, map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    # Restore the stored loss as well; otherwise the first validation after a
    # resume always beats the initial best_loss and overwrites the checkpoint.
    best_loss = checkpoint.get('loss', best_loss)
    # The stored epoch had already finished when the checkpoint was written,
    # so continue with the following one.
    start_epoch = checkpoint['epoch'] + 1
    
# Attacker: built in 'random' mode from the CLI arguments and moved to `device`.
adv = True
SA = SpatialAttacker('random', args, adv).to(device)

# Mask: circular mask when --mask is given, otherwise an all-ones tensor.
inshape = [32, 32]
maskin = make_mask(inshape) if args.mask else torch.ones(inshape)
maskin = maskin.to(device)
# NOTE(review): the next statement unconditionally replaces the mask selected
# above, so the circular mask is always applied and --mask currently has no
# effect. Kept as-is to preserve behaviour -- confirm whether it is intended.
maskin = make_mask([32, 32]).to(device)

# Regularizer: only looked up when --reg was supplied.
if args.reg is not None:
    reg = get_regularizer(args.reg)

### Training loop
for epoch in range(start_epoch, end_epoch):
    train(epoch)
    val_loss = valid(epoch)
    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

##### Standard accuracy
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)
print('test_normal')
std_acc = test()

fname = f'{args.net}.{args.train}.reg-{args.reg}.lamda-{args.lamda:.2f}.run-{args.run}.txt'
# The context manager closes the results file even if the grid evaluation
# below raises.
with open(fname, "w+") as f:
    f.write(f'\n{chckpt}\n')
    f.write(f'Std_acc:{std_acc}\n')

    #### Evaluation on grid
    # adversarial_test() repeats each batch once per grid entry, so the
    # loader must deliver one image at a time.
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)
    interps = ['bilinear', 'nearest', 'bicubic']
    grid = make_grid()
    checkpoint = torch.load(chckpt)
    net.load_state_dict(checkpoint['net'])
    net.eval()
    for interp_mode in interps:
        rot_acc, adv_acc = adversarial_test(net, interp_mode, grid)
        f.write(f'\n{interp_mode}:\t{rot_acc}\t{adv_acc}')

